package org.javaee7.jaspic.invoke.sam;
import static java.util.logging.Level.SEVERE;
import static javax.security.auth.message.AuthStatus.SEND_SUCCESS;
import static javax.security.auth.message.AuthStatus.SUCCESS;
import java.io.IOException;
import java.util.Map;
import java.util.logging.Logger;
import javax.enterprise.inject.spi.CDI;
import javax.naming.InitialContext;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.message.AuthException;
import javax.security.auth.message.AuthStatus;
import javax.security.auth.message.MessageInfo;
import javax.security.auth.message.MessagePolicy;
import javax.security.auth.message.callback.CallerPrincipalCallback;
import javax.security.auth.message.callback.GroupPrincipalCallback;
import javax.security.auth.message.module.ServerAuthModule;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.javaee7.jaspic.invoke.bean.CDIBean;
import org.javaee7.jaspic.invoke.bean.EJBBean;
/**
 * Test {@link ServerAuthModule} that checks whether a CDI bean or an EJB can be
 * obtained and invoked from inside each of the SAM callbacks.
 *
 * <p>When the request carries the parameter {@code tech=cdi} or {@code tech=ejb},
 * each of {@link #validateRequest}, {@link #secureResponse} and {@link #cleanSubject}
 * looks up the corresponding bean and writes lines of the form
 * {@code "<phase>: <text>"} to the response. A failure while calling a bean is
 * logged and otherwise ignored, so it does not change the status the module returns.
 *
 * <p>Authentication itself is unconditional: every request is authenticated as
 * caller {@code "test"} with the single group {@code "architect"}.
 *
 * @author Arjan Tijms
 */
public class TestServerAuthModule implements ServerAuthModule {

    private static final Logger logger = Logger.getLogger(TestServerAuthModule.class.getName());

    /** Message types handled by this module; never handed out directly (see getSupportedMessageTypes). */
    private static final Class<?>[] SUPPORTED_MESSAGE_TYPES = { HttpServletRequest.class, HttpServletResponse.class };

    /** Container-provided handler, set in {@link #initialize}. */
    private CallbackHandler handler;

    /**
     * Stores the container-provided callback handler. The message policies and
     * options are not used by this module.
     */
    @Override
    public void initialize(MessagePolicy requestPolicy, MessagePolicy responsePolicy, CallbackHandler handler,
            @SuppressWarnings("rawtypes") Map options) throws AuthException {
        this.handler = handler;
    }

    /**
     * Optionally invokes the bean selected by the {@code tech} request parameter,
     * then authenticates the caller as {@code "test"} in group {@code "architect"}.
     *
     * @return {@code SUCCESS} once the principal and group callbacks were handled
     * @throws AuthException if the callback handler fails; the original exception is the cause
     */
    @Override
    public AuthStatus validateRequest(MessageInfo messageInfo, Subject clientSubject, Subject serviceSubject) throws AuthException {
        callBeanForPhase(messageInfo, "validateRequest");

        try {
            handler.handle(new Callback[] {
                new CallerPrincipalCallback(clientSubject, "test"),
                new GroupPrincipalCallback(clientSubject, new String[] { "architect" })
            });
            return SUCCESS;
        } catch (IOException | UnsupportedCallbackException e) {
            // initCause returns Throwable, hence the cast back to AuthException.
            throw (AuthException) new AuthException().initCause(e);
        }
    }

    /**
     * Returns the message types this module supports.
     *
     * @return a fresh copy of the array, so callers cannot alter the module's own state
     */
    @Override
    public Class<?>[] getSupportedMessageTypes() {
        return SUPPORTED_MESSAGE_TYPES.clone();
    }

    /**
     * Optionally invokes the bean selected by the {@code tech} request parameter.
     *
     * @return always {@code SEND_SUCCESS}
     */
    @Override
    public AuthStatus secureResponse(MessageInfo messageInfo, Subject serviceSubject) throws AuthException {
        callBeanForPhase(messageInfo, "secureResponse");
        return SEND_SUCCESS;
    }

    /**
     * Optionally invokes the bean selected by the {@code tech} request parameter.
     * The given subject is not modified.
     */
    @Override
    public void cleanSubject(MessageInfo messageInfo, Subject subject) throws AuthException {
        callBeanForPhase(messageInfo, "cleanSubject");
    }

    /**
     * Dispatches to the CDI or EJB helper according to the {@code tech} request
     * parameter ({@code "cdi"} or {@code "ejb"}); does nothing for any other value.
     *
     * @param messageInfo holds the servlet request/response for this exchange
     * @param phase       name of the SAM callback, used as the output prefix
     */
    private void callBeanForPhase(MessageInfo messageInfo, String phase) {
        HttpServletRequest request = (HttpServletRequest) messageInfo.getRequestMessage();
        HttpServletResponse response = (HttpServletResponse) messageInfo.getResponseMessage();

        String tech = request.getParameter("tech");
        if ("cdi".equals(tech)) {
            callCDIBean(request, response, phase);
        } else if ("ejb".equals(tech)) {
            callEJBBean(response, phase);
        }
    }

    /**
     * Obtains {@link CDIBean} via {@code CDI.current()} and writes its text to the response.
     * It then calls {@code setTextViaInjectedRequest()} and writes the request
     * attribute {@code "text"}. NOTE(review): presumably that call sets the attribute
     * through an injected request; confirm against CDIBean.
     * Any exception is logged and swallowed (best-effort by design for this test).
     */
    private void callCDIBean(HttpServletRequest request, HttpServletResponse response, String phase) {
        try {
            CDIBean cdiBean = CDI.current().select(CDIBean.class).get();
            response.getWriter().write(phase + ": " + cdiBean.getText() + "\n");

            cdiBean.setTextViaInjectedRequest();
            response.getWriter().write(phase + ": " + request.getAttribute("text") + "\n");
        } catch (Exception e) {
            logger.log(SEVERE, "Invoking CDI bean failed in phase " + phase, e);
        }
    }

    /**
     * Looks up {@link EJBBean} in JNDI at {@code java:module/EJBBean} and writes its
     * text to the response. Any exception is logged and swallowed (best-effort by
     * design for this test).
     */
    private void callEJBBean(HttpServletResponse response, String phase) {
        try {
            EJBBean ejbBean = (EJBBean) new InitialContext().lookup("java:module/EJBBean");
            response.getWriter().write(phase + ": " + ejbBean.getText() + "\n");
        } catch (Exception e) {
            logger.log(SEVERE, "Invoking EJB bean failed in phase " + phase, e);
        }
    }
}